{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# MSC 基础 pass"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import set_env"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tvm.testing\n",
    "from tvm.relax.frontend.torch import from_fx\n",
    "from tvm.relax import PyExprVisitor\n",
    "\n",
    "from tvm.relay import testing\n",
    "from tvm.relay.expr_functor import ExprVisitor\n",
    "from tvm.relay.build_module import bind_params_by_name\n",
    "\n",
    "from tvm.contrib.msc.core import transform as msc_transform\n",
    "from tvm.contrib.msc.core import utils as msc_utils"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 为 `relax` 测试 `SetExprLayout`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# pylint: disable=import-outside-toplevel\n",
    "try:\n",
    "    import torch\n",
    "    import torchvision\n",
    "    from torch import fx\n",
    "except:  # pylint: disable=bare-except\n",
    "    print(\"please install pytorch python package\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "class RelaxLayoutChecker(PyExprVisitor):\n",
    "    \"\"\"检查是否设置了 `name` 作为 `span` 属性。\"\"\"\n",
    "\n",
    "    def check(self, expr):\n",
    "        self._missing_exprs = []\n",
    "        if isinstance(expr, tvm.relax.Expr):\n",
    "            self.visit_expr(expr)\n",
    "        elif isinstance(expr, tvm.relax.BindingBlock):\n",
    "            self.visit_binding_block(expr)\n",
    "        assert len(self._missing_exprs) == 0, f\"Missing {len(self._missing_exprs)} layouts\"\n",
    "\n",
    "    def visit_var_binding_(self, binding) -> None:\n",
    "        super().visit_var_binding_(binding)\n",
    "        if not msc_utils.get_expr_layout(binding.value):\n",
    "            self._missing_exprs.append(binding.value)\n",
    "\n",
    "    def visit_constant_(self, op) -> None:\n",
    "        super().visit_constant_(op)\n",
    "        if not msc_utils.get_expr_layout(op):\n",
    "            self._missing_exprs.append(op)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "torch_model = torchvision.models.resnet50()\n",
    "graph_model = fx.symbolic_trace(torch_model)\n",
    "input_info = [([1, 3, 224, 224], \"float32\")]\n",
    "with torch.no_grad():\n",
    "    mod = from_fx(graph_model, input_info)\n",
    "mod = msc_transform.SetExprLayout()(mod)\n",
    "RelaxLayoutChecker().check(mod)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 为 `relay` 测试 `SetExprName`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "class RelayNameChecker(ExprVisitor):\n",
    "    \"\"\"Check if name as span attribute is setted.\"\"\"\n",
    "\n",
    "    def check(self, expr):\n",
    "        self._missing_exprs = []\n",
    "        super().visit(expr)\n",
    "        assert len(self._missing_exprs) == 0, \"Missing {} names\".format(\n",
    "            len(self._missing_exprs)\n",
    "        )\n",
    "\n",
    "    def visit_constant(self, expr):\n",
    "        super().visit_constant(expr)\n",
    "        if not msc_utils.get_expr_name(expr):\n",
    "            self._missing_exprs.append(expr)\n",
    "\n",
    "    def visit_call(self, expr):\n",
    "        super().visit_call(expr)\n",
    "        if not msc_utils.get_expr_name(expr):\n",
    "            self._missing_exprs.append(expr)\n",
    "\n",
    "mod, params = testing.resnet.get_workload(num_layers=50, batch_size=1, dtype=\"float32\")\n",
    "mod[\"main\"] = bind_params_by_name(mod[\"main\"], params)\n",
    "mod = msc_transform.SetExprName(as_relax=False)(mod)\n",
    "RelayNameChecker().check(mod[\"main\"])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 为 `relax` 测试 `SetExprName `"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "class RelaxNameChecker(PyExprVisitor):\n",
    "    \"\"\"Check if name as span attribute is setted.\"\"\"\n",
    "\n",
    "    def check(self, expr):\n",
    "        self._missing_exprs = []\n",
    "        if isinstance(expr, tvm.relax.Expr):\n",
    "            self.visit_expr(expr)\n",
    "        elif isinstance(expr, tvm.relax.BindingBlock):\n",
    "            self.visit_binding_block(expr)\n",
    "        assert len(self._missing_exprs) == 0, \"Missing {} names\".format(\n",
    "            len(self._missing_exprs)\n",
    "        )\n",
    "\n",
    "    def visit_var_binding_(self, binding) -> None:\n",
    "        super().visit_var_binding_(binding)\n",
    "        if not msc_utils.get_expr_name(binding.value):\n",
    "            self._missing_exprs.append(binding.value)\n",
    "\n",
    "    def visit_constant_(self, op) -> None:\n",
    "        super().visit_constant_(op)\n",
    "        if not msc_utils.get_expr_name(op):\n",
    "            self._missing_exprs.append(op)\n",
    "\n",
    "torch_model = torchvision.models.resnet50()\n",
    "graph_model = fx.symbolic_trace(torch_model)\n",
    "input_info = [([1, 3, 224, 224], \"float32\")]\n",
    "with torch.no_grad():\n",
    "    mod = from_fx(graph_model, input_info)\n",
    "mod = msc_transform.SetExprName()(mod)\n",
    "RelaxNameChecker().check(mod)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "xxx",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
